[fix][train] Prompt-based mini-batching for step-wise training#1529
Merged
CharlieFRuan merged 5 commits into main on Apr 17, 2026
Conversation
This was referenced Apr 17, 2026
CharlieFRuan added a commit that referenced this pull request on Apr 19, 2026
Rebase PR #1479 onto current main (post-PRs #1507/#1526/#1527/#1529). The original E2E fix's `pad_batch` change is dropped since #1529's prompt-based mini-batch boundaries removed the need to pad to `mini_batch_size * n_samples`.
- `merge_stepwise_output()` in `trainer_utils.py` collapses multi-turn step-wise `GeneratorOutput` sequences into single sequences when consecutive turns share a common prefix, reducing training cost from O(T^2) to O(T).
- `trainer.py`: call merge before extracting generator fields, update `uids` from merged `trajectory_ids`, emit `generate/num_seq_{before,after}_merge`.
- Add `generator.merge_stepwise_output` config flag.
- `run_search.sh`: `MERGE_STEPWISE` env var.
- 16 CPU-only tests covering all 3 merging cases, partial merges, and validation asserts.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This was referenced Apr 19, 2026
CharlieFRuan added a commit that referenced this pull request on Apr 20, 2026
…1532)

## Summary
This PR implements prefix-aware merging for step-wise training, guarded by a flag `cfg.generator.merge_stepwise_output` that defaults to False. During step-wise training, within a trajectory, when consecutive steps share the same prefix (i.e. no re-tokenization drift or context management like thinking-token stripping), we collapse them into a single `GeneratorOutput` entry. This can reduce the O(T²) training cost introduced by step-wise training (T being the number of turns).

- `merge_stepwise_output()` in `skyrl/train/utils/trainer_utils.py` implements greedy merging: for consecutive turns in the same trajectory where `prompt[i] + response[i]` is a prefix of `prompt[i+1]`, merge into one entry. Response tokens are concatenated with the observation delta (loss-masked to 0) between turns; per-token fields (`loss_masks`, `rewards`, `rollout_logprobs`) align accordingly; per-turn fields (`stop_reason`, `is_last_step`, `trajectory_id`) take the last turn's value.
- `RayPPOTrainer.postprocess_generator_output` calls `merge_stepwise_output` when `generator.merge_stepwise_output=true`, updates `uids` from the merged `trajectory_ids`, and logs `generate/num_seq_{before,after}_merge`.
- Since `uids` may need to be modified, the signature of `postprocess_generator_output` now returns both `generator_output` and `uids`, changing various call sites.
- New `generator.merge_stepwise_output` config flag (default false).
- `examples/train/search/run_search.sh` accepts a `MERGE_STEPWISE=true` env var to pass the flag through.
- 16 CPU-only unit tests in `tests/train/test_merge_stepwise_output.py` cover the three merge cases, partial merges, prefix mismatches, single-turn passthrough, per-trajectory scalar rewards, and required-field asserts.
## Test plan
- [x] `pytest tests/train/test_merge_stepwise_output.py` — 16 passed
- [x] `pytest tests/train/test_trainer_utils.py tests/train/test_prompt_mini_batch.py` — 58 passed (existing tests unaffected)
- [x] E2E: Search-R1 step-wise GRPO run on Qwen2.5-3B-Instruct, 8×H100, `MERGE_STEPWISE=true`.

### Curves
With precisely the same setup as #1529, we do:
```bash
MERGE_STEPWISE=true USE_CONVERSATION_MULTI_TURN=true STEP_WISE=true bash examples/train/search/run_search.sh \
    generator.inference_engine.num_engines=8 \
    generator.inference_engine.tensor_parallel_size=1
```
See PR description for more.

Co-authored-by: Deep Sheth <deepsheth3@users.noreply.github.com>
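The greedy prefix merge described in the commit above can be sketched in a few lines. This is a simplified illustration, not the actual `merge_stepwise_output` implementation: the turn dicts and field names (`prompt`, `response`, `loss_mask`) are assumptions, and the real function also handles `rewards`, `rollout_logprobs`, and per-turn metadata.

```python
def merge_consecutive_turns(turns):
    """Greedily merge consecutive step-wise turns that extend a shared prefix.

    Each turn is a dict of token-id lists: `prompt`, `response`, and a
    per-token `loss_mask`. If prompt[i] + response[i] is a prefix of
    prompt[i+1], the two turns share a context and collapse into one
    training sequence; observation tokens between them get loss mask 0.
    """
    merged = [dict(turns[0])]
    for turn in turns[1:]:
        prev = merged[-1]
        prev_full = prev["prompt"] + prev["response"]
        if turn["prompt"][: len(prev_full)] == prev_full:
            # Tokens the environment inserted between the two turns.
            obs_delta = turn["prompt"][len(prev_full):]
            prev["response"] = prev["response"] + obs_delta + turn["response"]
            # Observation tokens are excluded from the loss.
            prev["loss_mask"] = (
                prev["loss_mask"] + [0] * len(obs_delta) + turn["loss_mask"]
            )
        else:
            # Prefix mismatch (e.g. re-tokenization drift): keep turns separate.
            merged.append(dict(turn))
    return merged
```

With two prefix-compatible turns such as `{"prompt": [1, 2], "response": [3]}` followed by `{"prompt": [1, 2, 3, 4], "response": [5]}`, the pair collapses into one entry whose response is `[3, 4, 5]` with loss mask `[1, 0, 1]`, matching the "concatenate with loss-masked observation delta" behavior described above.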
CharlieFRuan added a commit that referenced this pull request on Apr 20, 2026
…1538)

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Deep Sheth <deepsheth3@users.noreply.github.com>
Summary
Step-wise training decomposes multi-turn trajectories into one training sequence per LLM turn, producing a variable number of sequences per prompt, which broke the old fixed-size mini-batching.
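As a toy illustration of the breakage (the `uids` below are hypothetical, not from an actual run): slicing the flat sequence list into fixed-size chunks can split one prompt's sequences across mini-batches, and the chunk count, hence the number of optimizer steps, varies with how many sequences each prompt happened to generate.

```python
# One uid per training sequence; prompt "a" produced 3 step-wise
# sequences, "b" produced 1, and "c" produced 2.
uids = ["a", "a", "a", "b", "c", "c"]

# Old scheme: fixed-size slicing over the flat sequence list.
mini_batch_size = 2  # sequences per mini-batch
chunks = [uids[i : i + mini_batch_size] for i in range(0, len(uids), mini_batch_size)]
print(chunks)  # [['a', 'a'], ['a', 'b'], ['c', 'c']]
# Prompt "a" is split across the first two mini-batches, and the number
# of chunks (optimizer steps) depends on the total sequence count.
```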
This PR shifts mini-batching from sequence units to prompt units. Each mini-batch now contains sequences for exactly `policy_mini_batch_size` prompts, regardless of how many sequences those prompts generated. This ensures the number of optimizer steps is always `train_batch_size / policy_mini_batch_size * update_epochs_per_batch` (e.g., `train_batch_size=128`, `policy_mini_batch_size=32`, and `update_epochs_per_batch=1` always yield 4 optimizer steps). The overhead should be minimal, since the number of padded sequences is capped at `dp_size` for each mini-batch.

Key changes
- `compute_prompt_mini_batch_boundaries()` (`skyrl/train/dataset/preprocess.py`): walks a flat `uids` list, detects prompt boundaries by consecutive-equal groups, and slices them into `(start, end)` boundary pairs for each mini-batch. Asserts uid contiguity (a uid cannot reappear after a gap). Asserts `len(unique_uids) == train_batch_size`. For non-step-wise, asserts boundaries are uniform (backward compatible).
- `MeshDispatch.stage_chunks()` (`dispatch.py`): accepts `mini_batch_boundaries` instead of computing fixed-size chunks. Each mini-batch is individually padded to `dp_size` using `pad_training_input_batch()`.
- `_normalize_advantages()` and `_execute_training_step()` (`trainer.py`): iterate over boundary pairs instead of fixed-size slicing.
- `apply_loss_reduction_to_advantages_minibatch()` (`ppo_utils.py`): does not support `token_mean_legacy` for now, since `num_micro_batches` depends on how the mini-batch is padded.
- `WorkerDispatch.stage_data()` (`worker_dispatch.py`): passes boundaries through to `stage_chunks`.

Backward compatibility
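A minimal sketch of the prompt-boundary computation listed in the key changes. This is an illustration only; the real `compute_prompt_mini_batch_boundaries` lives in `skyrl/train/dataset/preprocess.py`, and the parameter name `prompts_per_mini_batch` is an assumption, not its actual signature.

```python
from itertools import groupby

def compute_prompt_mini_batch_boundaries(uids, prompts_per_mini_batch):
    """Slice a flat per-sequence uids list into prompt-based mini-batches.

    Groups consecutive-equal uids (one group per prompt), then takes
    `prompts_per_mini_batch` groups per mini-batch, returning (start, end)
    index pairs into the flat sequence list.
    """
    # Size of each consecutive-equal uid group.
    group_sizes = [len(list(g)) for _, g in groupby(uids)]
    # Contiguity check: a uid re-appearing after a gap creates an extra group.
    assert len(set(uids)) == len(group_sizes), "uid re-appeared after a gap"

    boundaries, start = [], 0
    for i in range(0, len(group_sizes), prompts_per_mini_batch):
        end = start + sum(group_sizes[i : i + prompts_per_mini_batch])
        boundaries.append((start, end))
        start = end
    return boundaries

# Step-wise: variable sequences per prompt -> non-uniform boundaries.
print(compute_prompt_mini_batch_boundaries(["a", "a", "a", "b", "c", "c"], 2))
# [(0, 4), (4, 6)]

# Non-step-wise: uniform groups -> identical to fixed-size slicing.
print(compute_prompt_mini_batch_boundaries(["a", "a", "b", "b", "c", "c"], 1))
# [(0, 2), (2, 4), (4, 6)]
```

The second call shows why the change is backward compatible: with a uniform `n_samples_per_prompt`, every boundary pair has the same width, matching the original fixed-size chunks.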
For non-step-wise training, where each prompt has exactly `n_samples_per_prompt` sequences, boundaries remain uniform — identical to the original fixed-size slicing. An assertion in `compute_prompt_mini_batch_boundaries` verifies this.

Test plan
- `tests/train/test_prompt_mini_batch.py`: unit tests for `compute_prompt_mini_batch_boundaries` (non-step-wise, step-wise, contiguity assertion, boundary uniformity parametrized), `MeshDispatch.stage_chunks` (padding, loss_mask zeros, variable sizes), and optimizer step count invariance.
- `skyrl-search-padding`.

Search-r1 Curves
Report link: https://api.wandb.ai/links/sky-posttraining-uc-berkeley/c43eauat
1. Comparing that non-step-wise is not affected before and after this PR, and step-wise vs. non-step-wise for this PR. Note: the `token_mean` loss reduction changed after this PR ([BREAKING][skyrl-train] Implement loss reduction via advantage normalization and fix `token_mean` reduction strategy, #1296).
2. Step-wise across PRs
Analysis:
- `num_mini_batches = len(data) // mini_batch_size`, and slice with this `num_mini_batches`
- `dispatch_from_staged` to individually serialize DP chunks to avoid materializing the whole batch on all workers (#1376)
- `policy_mini_batch_size * n_samples_per_prompt`, which is way too much padding

Commands used
On 8xH100s.
Non-stepwise:
Step-wise:
🤖 Generated with Claude Code